{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "import jieba\n",
    "import os\n",
    "import pickle   #持久化\n",
    "from numpy import *\n",
    "from sklearn import feature_extraction\n",
    "from sklearn.feature_extraction.text import TfidfTransformer #TF-IDF向量转换类\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer #TF_IDF向量生成类\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.feature_extraction import DictVectorizer \n",
    "from sklearn.datasets.base import Bunch\n",
    "from sklearn.naive_bayes import MultinomialNB #多项式贝叶斯算法\n",
    "import gc #解決memory error問題\n",
    "from sklearn.utils import shuffle\n",
    "import random\n",
    "import gensim\n",
    "from gensim.models import word2vec\n",
    "import pickle\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import svm\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder,OneHotEncoder\n",
    "from tensorflow.keras.layers import LSTM, Activation, Dense, Dropout, Input, Embedding\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import RMSprop\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from keras.utils import np_utils\n",
    "from tensorflow.keras import Sequential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole dataset file and split it into a list with one record string per line.\n",
    "with open('toutiao_cat_data.txt', mode='r', encoding='utf8') as dataset_file:\n",
    "    data = dataset_file.read().splitlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each record is a '_!_'-delimited line; fields 1, 3 and 4 are collected as label, title and keyword.\n",
    "FIELD_SEP = '_!_'\n",
    "\n",
    "label = []\n",
    "title = []\n",
    "keyword = []\n",
    "\n",
    "for line_no, line in enumerate(data, start=1):\n",
    "    if not line.strip():\n",
    "        continue  # tolerate blank lines (e.g. a trailing empty line) instead of failing with IndexError\n",
    "    fields = line.split(FIELD_SEP)  # split once per record instead of three times\n",
    "    if len(fields) < 5:\n",
    "        raise ValueError('line %d: expected at least 5 fields, got %d' % (line_no, len(fields)))\n",
    "    label.append(fields[1])\n",
    "    title.append(fields[3])\n",
    "    keyword.append(fields[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>title</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>王者荣耀到后期了，“泣血之刃”好还是“制裁之刃”好？</td>\n",
       "      <td>116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>龙文4家企业入选省重点上市后备企业名单，看看有没有你熟悉的？</td>\n",
       "      <td>104</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>带金毛坐拖拉机，这货一脸不乐意，网友：没让你耕地就不错了！</td>\n",
       "      <td>115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>文化人最能装：《金瓶梅》刻画得确实逼真！</td>\n",
       "      <td>101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>如果奥创得逞，他能打过灭霸吗？</td>\n",
       "      <td>102</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            title label\n",
       "0      王者荣耀到后期了，“泣血之刃”好还是“制裁之刃”好？   116\n",
       "1  龙文4家企业入选省重点上市后备企业名单，看看有没有你熟悉的？   104\n",
       "2   带金毛坐拖拉机，这货一脸不乐意，网友：没让你耕地就不错了！   115\n",
       "3            文化人最能装：《金瓶梅》刻画得确实逼真！   101\n",
       "4                 如果奥创得逞，他能打过灭霸吗？   102"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assemble the title/label table directly from the parsed lists. The keyword field is not\n",
    "# part of the final table, so it is not concatenated in only to be dropped again, and the\n",
    "# title/label/keyword lists are left untouched, which keeps this cell safe to re-run.\n",
    "trainData = pd.DataFrame({'title': title, 'label': label})\n",
    "\n",
    "# Shuffle with a fixed seed so the row order (and everything derived from it) is reproducible.\n",
    "trainData = shuffle(trainData, random_state=0)\n",
    "\n",
    "# trainData = trainData[:20000]\n",
    "trainData = trainData.reset_index(drop=True)  # row index restarts at 0\n",
    "trainData.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "55"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Work around memory errors: delete intermediates that are no longer needed and run the garbage collector.\n",
    "# del label\n",
    "del title\n",
    "del keyword\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop words: one entry per line of stopwords.txt, with surrounding whitespace stripped.\n",
    "# The file is opened in a `with` block so the handle is closed as soon as it has been read.\n",
    "with open('stopwords.txt', 'r', encoding='utf8') as stopword_file:\n",
    "    stopwords = [line.strip() for line in stopword_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "utf-8\n"
     ]
    }
   ],
   "source": [
    "# Show the interpreter's default string encoding.\n",
    "print(sys.getdefaultencoding())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "I0801 14:48:55.645139 12564 __init__.py:111] Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\jack\\AppData\\Local\\Temp\\jieba.cache\n",
      "I0801 14:48:55.718110 12564 __init__.py:131] Loading model from cache C:\\Users\\jack\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 0.924 seconds.\n",
      "I0801 14:48:56.569814 12564 __init__.py:163] Loading model cost 0.924 seconds.\n",
      "Prefix dict has been built succesfully.\n",
      "I0801 14:48:56.569814 12564 __init__.py:164] Prefix dict has been built succesfully.\n"
     ]
    }
   ],
   "source": [
    "# Segment every title with jieba, drop stop words, and keep the result both in memory\n",
    "# (splitTitle) and on disk (title_split.txt, one space-separated title per line).\n",
    "stopword_set = set(stopwords)  # constant-time membership test instead of scanning the list per token\n",
    "\n",
    "splitTitle = []\n",
    "# `with` guarantees the file is flushed and closed even if segmentation fails part-way through.\n",
    "with open(\"title_split.txt\", \"w\", encoding='utf8') as title_file:\n",
    "    # The names `split` and `temp` are kept on purpose: a later cell deletes them explicitly.\n",
    "    for i in range(len(trainData.title)):\n",
    "        split = jieba.cut(trainData.title[i], cut_all=False)\n",
    "        # Keep only tokens that are not stop words, separated by single spaces.\n",
    "        temp = \" \".join(word for word in split if word not in stopword_set).strip()\n",
    "\n",
    "        title_file.write(temp + \"\\n\")  # persist the segmented title (stop words already removed)\n",
    "        splitTitle.append(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "197"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Release objects from the segmentation step that are no longer needed, then collect garbage.\n",
    "del stopwords\n",
    "del split\n",
    "del temp\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TfidfTransformer + CountVectorizer  =  TfidfVectorizer\n",
    "\n",
    "#该类会将文本中的词语转换为词频矩阵，矩阵元素a[i][j] 表示j词在i类文本下的词频\n",
    "# vectorizer=CountVectorizer() \n",
    "#该类会统计每个词语的tf-idf权值\n",
    "# transformer=TfidfTransformer()   \n",
    "#第一个fit_transform是计算tf-idf，第二个fit_transform是将文本转为词频矩阵\n",
    "# tfidf=transformer.fit_transform(vectorizer.fit_transform(splitTitle))\n",
    "# word=vectorizer.get_feature_names()\n",
    "# weight=tfidf.toarray()\n",
    "\n",
    "# title_vectorizer = TfidfVectorizer(max_df = 0.5, min_df=2)\n",
    "# title_tfidf = title_vectorizer.fit_transform(splitTitle)\n",
    "# title_weight = title_tfidf.toarray()   #将tf-idf矩阵抽取出来，元素a[i][j]表示j词在i类文本中的tf-idf权重\n",
    "# title_word = title_vectorizer.get_feature_names()   #获取词袋模型中的所有词语\n",
    "\n",
    "# keyword_vectorizer = TfidfVectorizer(sublinear_tf = True, max_df = 0.5)\n",
    "# keyword_tfidf = keyword_vectorizer.fit_transform(splitTitle)\n",
    "# keyword_weight = keyword_tfidf.toarray()\n",
    "# keyword_word = keyword_vectorizer.get_feature_names()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# splitTitle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word2Vec(vocab=169098, size=300, alpha=0.025)\n"
     ]
    }
   ],
   "source": [
    "# Train 300-dimensional word vectors on the segmented titles stored in title_split.txt.\n",
    "# min_count=1 is passed so that rare words are presumably kept in the vocabulary - confirm against the gensim docs.\n",
    "# NOTE(review): no seed is passed, so the vectors presumably differ between runs - confirm.\n",
    "# NOTE(review): `model` is rebound to the Keras network in a later cell; cells that need the\n",
    "# word vectors must run before that one.\n",
    "sentences = word2vec.LineSentence(\"title_split.txt\")\n",
    "model = word2vec.Word2Vec(sentences, size=300, min_count=1)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for words in sentences:\n",
    "#     print(words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda\\lib\\site-packages\\ipykernel_launcher.py:13: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).\n",
      "  del sys.path[0]\n",
      "D:\\Anaconda\\lib\\site-packages\\ipykernel_launcher.py:18: RuntimeWarning: invalid value encountered in true_divide\n"
     ]
    }
   ],
   "source": [
    "# Build one document vector per title by averaging the word2vec vectors of its words.\n",
    "# The loop runs over splitTitle (one entry per row of trainData) rather than over the\n",
    "# LineSentence stream: LineSentence yields nothing for an empty line, which would leave\n",
    "# vecs shorter than, and misaligned with, the labels whenever a title ends up empty.\n",
    "VECTOR_SIZE = 300  # must match the `size` the Word2Vec model was trained with\n",
    "vecs = []\n",
    "\n",
    "for sentence in splitTitle:\n",
    "    words = sentence.split()\n",
    "    vec = np.zeros((1, VECTOR_SIZE))\n",
    "    count = 0\n",
    "    # Average over every token, including the first one (a words[1:] slice would drop it).\n",
    "    for word in words:\n",
    "        try:\n",
    "            word_vec = model.wv[word]  # model[word] is deprecated in favour of model.wv[word]\n",
    "        except KeyError:\n",
    "            continue  # out-of-vocabulary word: skip it and do not count it in the average\n",
    "        vec = vec + word_vec.reshape((1, VECTOR_SIZE))\n",
    "        count = count + 1\n",
    "    if count > 0:\n",
    "        vec = vec / count\n",
    "    else:\n",
    "        # No usable word: mark the row with NaN explicitly (instead of relying on a 0/0\n",
    "        # division and its RuntimeWarning) so the NaN filter further down removes it.\n",
    "        vec = np.full((1, VECTOR_SIZE), np.nan)\n",
    "    vecs.append(vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the dataset labels in trainData.label as class ids, shaped as a single column.\n",
    "le = LabelEncoder()\n",
    "le.fit(trainData.label)\n",
    "new_label = le.transform(trainData.label).reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(382688, 1)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: one encoded label per row, in a single column.\n",
    "new_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the label array to a plain Python list so entries can be dropped together\n",
    "# with the matching entries of vecs.\n",
    "new_label = new_label.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # 3維轉2維\n",
    "# vecs = np.reshape(vecs, (20000,300))\n",
    "# print(vecs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop every sample whose document vector contains NaN, together with its label.\n",
    "# The kept rows are rebuilt in a single pass; deleting from the middle of a list one item\n",
    "# at a time shifts all later items on every deletion.\n",
    "if len(vecs) != len(new_label):\n",
    "    raise ValueError('vecs and new_label are misaligned: %d vs %d' % (len(vecs), len(new_label)))\n",
    "\n",
    "keep = [not np.isnan(vec).any() for vec in vecs]\n",
    "vecs = [vec for vec, ok in zip(vecs, keep) if ok]\n",
    "new_label = [lab for lab, ok in zip(new_label, keep) if ok]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vecs = pd.DataFrame(vecs)\n",
    "# vecs = vecs.fillna(vecs.mean())  # 把0補成平均值\n",
    "# vecs = np.array(vecs, dtype='float64')\n",
    "# print(vecs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vecs = vecs[:,:,np.newaxis]\n",
    "# print(vecs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Verify that no NaN is left in the document vectors (expected result: False).\n",
    "np.isnan(vecs).any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(382308, 15)\n"
     ]
    }
   ],
   "source": [
    "# One-hot encode the class ids into 15 columns; the list is converted to an ndarray first.\n",
    "new_label = np_utils.to_categorical(np.array(new_label, dtype='float64'), 15)\n",
    "print(new_label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(382308, 1, 300)\n"
     ]
    }
   ],
   "source": [
    "# Convert the list of document vectors to a float64 ndarray and show its shape.\n",
    "vecs = np.array(vecs, dtype='float64')\n",
    "print(vecs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# new_label = new_label[:,:,np.newaxis]\n",
    "# new_label = new_label[:,np.newaxis,:]\n",
    "# print(new_label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the one-hot label of the first sample.\n",
    "new_label[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 随机选择20%作为测试集，剩余作为训练集\n",
    "X_train, X_test, y_train, y_test = train_test_split(vecs, new_label, test_size=0.20, random_state=0)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 2.58074305e-01  1.58612642e-01 -1.44940097e-01  1.29025510e-01\n",
      "   1.03847934e-01  1.28651097e-01  1.50737151e-01 -1.46670424e-01\n",
      "   8.04184579e-02  7.52745291e-02 -2.90296446e-01 -5.14505968e-02\n",
      "  -1.98715075e-01 -1.20859365e-01  2.72466362e-01 -1.25175389e-01\n",
      "   5.38247213e-02  4.45781554e-02 -6.23296700e-02 -1.06773611e-01\n",
      "   6.92078574e-02  1.52581002e-01  9.75786674e-02  3.41345274e-02\n",
      "   4.84860047e-02  7.38384965e-02 -1.42878508e-01 -7.12984825e-02\n",
      "  -1.08816757e-01 -1.19982400e-01  1.55660516e-01  3.89198248e-02\n",
      "  -9.32216638e-02 -1.87144537e-01  1.15266062e-01  2.02732196e-01\n",
      "   1.13264006e-02 -4.30892163e-02  8.55203599e-02  1.37583117e-02\n",
      "   6.00714810e-02  1.02429857e-01 -2.64415396e-01 -1.34703120e-01\n",
      "  -1.00851678e-01 -1.32131469e-01  3.03077609e-02  1.76636677e-01\n",
      "  -1.70558867e-01  1.96760715e-01  7.79409857e-02 -1.54087906e-01\n",
      "   1.04937453e-01 -1.13677219e-01 -2.26069128e-01 -4.13062514e-03\n",
      "   7.85268647e-03 -5.30659871e-02  6.19218331e-02 -1.31557771e-01\n",
      "  -2.30258452e-02 -2.50801381e-02 -1.20876334e-01 -2.80589037e-01\n",
      "   1.00919099e-01  1.08515204e-01  1.65791444e-02  2.27384598e-01\n",
      "   4.59784200e-02  1.67377790e-01  1.60468241e-01 -1.19975286e-01\n",
      "  -1.11423281e-01  7.10362609e-02 -3.61700084e-01  1.32337288e-01\n",
      "  -4.68988973e-01 -2.75815174e-01 -2.72524244e-02 -1.70043780e-01\n",
      "  -3.46894923e-01 -2.42046546e-01 -3.94403232e-02 -2.54213024e-02\n",
      "   4.93347875e-03 -8.93401488e-03  6.86809179e-02 -1.28781683e-01\n",
      "   5.67616049e-02 -6.72865726e-02  9.66199466e-02 -7.24732850e-02\n",
      "  -2.15782280e-01 -2.36514435e-01  2.84993920e-01 -1.85568865e-02\n",
      "   8.03568416e-02 -1.72627655e-01 -2.09377087e-01 -5.57153289e-03\n",
      "  -2.10045889e-01 -2.10697824e-01 -1.68697516e-01  2.15972819e-01\n",
      "  -2.13352301e-01  2.27851701e-01  1.34238603e-01  1.53547571e-01\n",
      "   3.99115450e-02 -8.79954454e-02 -2.40023200e-02  1.12650933e-01\n",
      "  -1.03740455e-01 -6.70384137e-02 -1.03895215e-02  8.35539805e-03\n",
      "   1.02209291e-01  2.50076557e-01 -2.09564709e-01  1.76029601e-01\n",
      "   1.75026415e-01  6.95223719e-02 -1.00398668e-02 -2.77945965e-02\n",
      "   2.60246490e-01 -7.36264118e-02  1.86681575e-02 -1.12732491e-02\n",
      "   2.61233552e-01 -9.51885008e-02 -1.24663972e-01 -9.66322664e-02\n",
      "  -2.26128610e-02  8.67178510e-02 -1.48560137e-01  4.83117388e-02\n",
      "  -2.85514566e-02 -2.47773134e-01  2.57319009e-01 -1.44283120e-01\n",
      "   4.10937449e-02 -1.80491794e-01  1.50795111e-01  2.43116796e-01\n",
      "  -2.74318060e-02 -1.86882848e-01  3.13183157e-02  7.35491250e-02\n",
      "  -3.01356121e-01 -3.80726912e-02 -1.31651867e-01 -9.79373505e-02\n",
      "  -2.44113612e-01 -7.30591510e-02  1.45835525e-01  1.60223449e-01\n",
      "   2.68312865e-02 -5.42971448e-02 -2.59559352e-01  1.54369159e-01\n",
      "  -1.12598635e-01  3.14473940e-01 -1.86823111e-01 -1.02418993e-02\n",
      "  -1.28786669e-01 -4.72421822e-02 -3.97209192e-01  8.71482526e-03\n",
      "   1.29087730e-01 -1.12218300e-01  2.73632642e-03  4.89839312e-02\n",
      "   7.00918886e-02 -5.35398241e-02  2.35688993e-01 -4.04632810e-02\n",
      "  -1.65209657e-01 -1.06050446e-01  3.62711833e-02  1.15096765e-01\n",
      "   1.23675597e-01 -1.33460945e-01  7.39434685e-02  5.56662427e-02\n",
      "  -8.51941068e-02  3.36487167e-01 -1.32218818e-04  3.28848230e-01\n",
      "  -5.20536439e-02 -1.78807593e-02  1.42545983e-01 -3.04393670e-01\n",
      "   7.39538966e-02  8.39883524e-02 -1.26684130e-01  1.97446014e-03\n",
      "  -2.01407295e-01  9.15312009e-02  6.18504018e-02  7.70337589e-02\n",
      "   9.20366919e-02  1.62849838e-01 -2.65881604e-01 -3.49414064e-01\n",
      "  -1.60092850e-01 -1.32370419e-01  2.28847683e-02  1.54337845e-01\n",
      "  -6.89253730e-02 -3.52870394e-02 -4.62424418e-02 -7.04194641e-02\n",
      "  -1.23214858e-01  1.37673920e-01 -7.38540720e-02 -1.98690699e-01\n",
      "   4.06518424e-02  5.50806612e-02  2.33349199e-01 -3.16089488e-01\n",
      "  -1.32823348e-03 -1.48756738e-01 -1.52498053e-01 -1.42028391e-01\n",
      "  -3.94203856e-01  8.37272009e-02  1.59273180e-01  6.04824532e-02\n",
      "  -1.56551577e-01  1.36653140e-01 -1.59001579e-02 -2.09495223e-01\n",
      "  -8.64534854e-02  1.64633460e-01 -1.48427037e-01 -4.78899471e-02\n",
      "  -1.44285501e-01 -5.32131601e-02 -3.11337503e-01  3.84335322e-03\n",
      "  -1.65061555e-02  2.06965891e-02  9.86198778e-02  8.91932904e-02\n",
      "  -2.70800520e-02  7.12183092e-02  3.11106765e-01  3.81759857e-03\n",
      "   1.73193024e-01 -1.34557065e-02  1.67797739e-01  9.00663154e-02\n",
      "   8.80778204e-02  7.86712080e-02  1.83436565e-02  8.36286020e-03\n",
      "  -7.65404310e-02 -6.69130952e-03  1.93678058e-02 -3.67949781e-01\n",
      "   1.17416164e-02 -8.58186419e-02 -1.14809213e-02 -6.07936005e-02\n",
      "  -1.64679842e-01  7.75259687e-02 -7.23268541e-02  2.89561332e-01\n",
      "  -2.01238060e-01 -4.83354257e-02  1.53571261e-02  1.64917392e-03\n",
      "   6.99831208e-02  2.64606148e-02 -1.54379534e-01 -7.08982475e-02\n",
      "  -1.25198987e-01  3.45380334e-02 -1.78763691e-01  2.81484823e-01\n",
      "  -9.43294485e-02 -2.05062197e-02 -1.33247215e-01 -2.21977152e-02\n",
      "  -1.82860950e-01 -9.56688830e-02  3.99837084e-02  1.14754226e-01\n",
      "   2.32891560e-01  1.34319664e-01  4.32577400e-02 -5.73615690e-02\n",
      "  -2.03639833e-01  6.71599742e-02  1.95512457e-02  4.22198656e-02\n",
      "   4.93081516e-02  4.15842506e-01  2.10933187e-02  5.50937555e-02]]\n",
      "(76462, 1, 300)\n",
      "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "(76462, 15)\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first training sample and its label, and at the shapes of the test arrays.\n",
    "for item in (X_train[0], X_test.shape, y_train[0], y_test.shape):\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clf = svm.SVC()  # 使用RBF核\n",
    "# clf_res = clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0801 14:52:00.408866 12564 deprecation.py:323] From D:\\Anaconda\\lib\\site-packages\\tensorflow\\python\\ops\\math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "lstm (LSTM)                  (None, 1, 128)            219648    \n",
      "_________________________________________________________________\n",
      "dropout (Dropout)            (None, 1, 128)            0         \n",
      "_________________________________________________________________\n",
      "lstm_1 (LSTM)                (None, 32)                20608     \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 32)                0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 15)                495       \n",
      "=================================================================\n",
      "Total params: 240,751\n",
      "Trainable params: 240,751\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Train on 305846 samples, validate on 76462 samples\n",
      "Epoch 1/10\n",
      "305846/305846 [==============================] - 24s 77us/sample - loss: 1.0633 - accuracy: 0.6894 - val_loss: 0.9326 - val_accuracy: 0.7191\n",
      "Epoch 2/10\n",
      "305846/305846 [==============================] - 20s 65us/sample - loss: 0.9652 - accuracy: 0.7131 - val_loss: 0.9141 - val_accuracy: 0.7250\n",
      "Epoch 3/10\n",
      "305846/305846 [==============================] - 20s 66us/sample - loss: 0.9407 - accuracy: 0.7207 - val_loss: 0.8899 - val_accuracy: 0.7326\n",
      "Epoch 4/10\n",
      "305846/305846 [==============================] - 20s 64us/sample - loss: 0.9257 - accuracy: 0.7249 - val_loss: 0.8790 - val_accuracy: 0.7354\n",
      "Epoch 5/10\n",
      "305846/305846 [==============================] - 20s 66us/sample - loss: 0.9151 - accuracy: 0.7285 - val_loss: 0.8701 - val_accuracy: 0.7363\n",
      "Epoch 6/10\n",
      "305846/305846 [==============================] - 21s 69us/sample - loss: 0.9065 - accuracy: 0.7313 - val_loss: 0.8622 - val_accuracy: 0.7399\n",
      "Epoch 7/10\n",
      "305846/305846 [==============================] - 23s 74us/sample - loss: 0.9002 - accuracy: 0.7320 - val_loss: 0.8597 - val_accuracy: 0.7409\n",
      "Epoch 8/10\n",
      "305846/305846 [==============================] - 22s 73us/sample - loss: 0.8920 - accuracy: 0.7354 - val_loss: 0.8551 - val_accuracy: 0.7418\n",
      "Epoch 9/10\n",
      "305846/305846 [==============================] - 22s 71us/sample - loss: 0.8884 - accuracy: 0.7369 - val_loss: 0.8506 - val_accuracy: 0.7440\n",
      "Epoch 10/10\n",
      "305846/305846 [==============================] - 22s 72us/sample - loss: 0.8826 - accuracy: 0.7388 - val_loss: 0.8457 - val_accuracy: 0.7447\n"
     ]
    }
   ],
   "source": [
    "# Two stacked LSTM layers with dropout, followed by a softmax classifier.\n",
    "# NOTE: this rebinds `model`, which held the Word2Vec model until now.\n",
    "n_classes = y_train.shape[1]  # width of the one-hot labels, instead of a hard-coded 15\n",
    "\n",
    "model = Sequential()\n",
    "\n",
    "# Input shape (dims 1 and 2 of X_train) is read from the data; return_sequences=True\n",
    "# passes the full output sequence on to the second LSTM layer.\n",
    "model.add(LSTM(128, input_shape=(X_train.shape[1], X_train.shape[2]), return_sequences=True))\n",
    "model.add(Dropout(0.2))\n",
    "\n",
    "model.add(LSTM(32))\n",
    "model.add(Dropout(0.2))\n",
    "\n",
    "model.add(Dense(n_classes, activation='softmax'))\n",
    "\n",
    "model.summary()\n",
    "model.compile(loss=\"categorical_crossentropy\", optimizer='rmsprop', metrics=[\"accuracy\"])\n",
    "\n",
    "# NOTE(review): the test split doubles as validation data here, so the validation metrics are\n",
    "# not an independent estimate if they are ever used for tuning or early stopping.\n",
    "model_fit = model.fit(X_train, y_train, batch_size=100, epochs=10,\n",
    "                      validation_data=(X_test, y_test),\n",
    "#                      callbacks=[EarlyStopping(monitor='val_loss',min_delta=0.0001)] ## stop training once val_loss no longer improves\n",
    "                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# max_len = 100\n",
    "# max_words = 5000\n",
    "\n",
    "# ## LSTM模型\n",
    "# inputs = Input(name='inputs',shape=[max_len])\n",
    "# ## Embedding(詞彙表大小,batch大小,每個新聞的詞長)\n",
    "# layer = Embedding(max_words+1,128,input_length=max_len)(inputs)\n",
    "# layer = LSTM(128)(layer)\n",
    "# layer = Dense(128,activation=\"relu\",name=\"FC1\")(layer)\n",
    "# layer = Dropout(0.5)(layer)\n",
    "# layer = Dense(15,activation=\"softmax\",name=\"FC2\")(layer)\n",
    "\n",
    "# model = Model(inputs=inputs,outputs=layer)\n",
    "# model.summary()\n",
    "# model.compile(loss=\"categorical_crossentropy\",optimizer=RMSprop(),metrics=[\"accuracy\"])\n",
    "\n",
    "# model_fit = model.fit(X_train,y_train,batch_size=128,epochs=10,\n",
    "#                       validation_data=(X_test,y_test),\n",
    "#                       callbacks=[EarlyStopping(monitor='val_loss',min_delta=0.0001)] ## 當val-loss不再提升時停止訓練\n",
    "#                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # train_pred = clf_res.predict(X_train)\n",
    "# test_pred = clf_res.predict(X_test)\n",
    "# print(classification_report(y_test, test_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted class probabilities for every test sample.\n",
    "test_pred = model.predict(X_test)\n",
    "\n",
    "# classification_report needs class ids, not one-hot rows or probabilities, so both sides are\n",
    "# reduced with argmax. Rows of the report are named after the original labels known to the encoder.\n",
    "target_names = [str(c) for c in le.classes_]\n",
    "print(classification_report(y_test.argmax(axis=1), test_pred.argmax(axis=1),\n",
    "                            labels=np.arange(len(target_names)), target_names=target_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([8.6684922e-05, 8.9917833e-04, 1.4269979e-03, 6.2086654e-04,\n",
       "       5.1810939e-02, 6.4793677e-04, 2.3059237e-03, 1.8959485e-03,\n",
       "       9.3046373e-01, 2.3602888e-03, 3.0079446e-04, 3.6617669e-03,\n",
       "       6.4990658e-05, 1.1413632e-03, 2.3125845e-03], dtype=float32)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the predicted class probabilities of the first test sample.\n",
    "test_pred[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "305846/305846 [==============================] - 16s 52us/sample - loss: 0.8163 - accuracy: 0.7534\n",
      "\n",
      "train Acc: 0.7533955\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the training split; index 1 of the result is the accuracy metric set in compile().\n",
    "result = model.evaluate(X_train, y_train)\n",
    "print ('\\ntrain Acc:', result[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "76462/76462 [==============================] - 4s 54us/sample - loss: 0.8457 - accuracy: 0.7447\n",
      "\n",
      "test Acc: 0.74469674\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the held-out test split; index 1 of the result is the accuracy metric set in compile().\n",
    "result = model.evaluate(X_test, y_test)\n",
    "print ('\\ntest Acc:', result[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
